Generative Modeling¶
Lecture Date: 2025-03-14¶
Four models are detailed in this notebook: make sure to prepare the DataLoaders first, and then you can jump to any of the other model sections and run them. There is some overlap between the implementations, but each one utilizes unique changes to the architecture and/or training regime for that particular approach.
- Prepare MNIST DataLoaders (do this first, then jump to a model) (MNIST Data Prep)
- Generative Adversarial Network (GAN)
- Auxiliary Classifier Adversarial Network (AC-GAN)
- Denoising Diffusion Probabilistic Model (DDPM)
- Conditioned Denoising Diffusion Probabilistic Model (CDDPM)
import numpy as np
import torch
import torchvision
import torchmetrics
import lightning.pytorch as pl
from torchinfo import summary
from torchview import draw_graph
import matplotlib.pyplot as plt
import pandas as pd
# Minor tensor-core speedup
torch.set_float32_matmul_precision('medium')
class GenerativeMNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST to the generative models below.

    Pixel values are rescaled from [0, 255] to [-1, 1] to match the Tanh
    output range of the generators.  The predict split yields 80 all-zero
    "images" (only their shape is consumed by ``predict_step``) paired with
    the labels 0..9 repeated 8 times each, so conditional models generate
    an 8-sample row per digit.
    """
    def __init__(self,
                 batch_size = 128,
                 num_workers = 4,
                 location = "~/datasets",
                 **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Bug fix: this was hard-coded to "~/datasets", silently ignoring
        # the ``location`` argument passed by the caller.
        self.location = location
        self.input_shape = (1,28,28)
        self.output_shape = (1,28,28)
        # Populated lazily in setup() so the module is cheap to construct.
        self.data_train = None
        self.data_test = None
        self.data_predict = None

    def setup(self, stage: str):
        """Download/convert the split(s) needed for the given stage."""
        if stage == 'fit' and not self.data_train:
            training_dataset = torchvision.datasets.MNIST(root=self.location,download=True, train=True)
            # Add the channel axis: (N, 28, 28) -> (N, 1, 28, 28).
            x_train = np.expand_dims(training_dataset.data,1)
            # [0, 255] -> [-1, 1].
            x_train = (x_train / 127.5) - 1.0
            y_train = np.array(training_dataset.targets)
            self.data_train = list(zip(torch.Tensor(x_train).to(torch.float32),
                                       torch.Tensor(y_train).to(torch.long)))
        if (stage == 'test' or \
            stage == 'predict') and \
           not(self.data_test and self.data_predict):
            testing_dataset = torchvision.datasets.MNIST(root=self.location,download=True, train=False)
            x_test = np.expand_dims(testing_dataset.data,1)
            x_test = (x_test / 127.5) - 1.0
            y_test = np.array(testing_dataset.targets)
            self.data_test = list(zip(torch.Tensor(x_test).to(torch.float32),
                                      torch.Tensor(y_test).to(torch.long)))
            # Placeholder zeros + labels 0..9 repeated 8x (80 items total).
            self.data_predict = list(zip(torch.zeros(x_test[:80].shape).to(torch.float32),
                                         torch.repeat_interleave(torch.arange(10),8).to(torch.long)))

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.data_train,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=True)

    # No validation split is defined for these generative models.
    # def val_dataloader(self):
    #     return torch.utils.data.DataLoader(self.data_val,
    #                                        batch_size=self.batch_size,
    #                                        num_workers=self.num_workers,
    #                                        shuffle=False)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.data_test,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False)

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.data_predict,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False)
# Shared DataModule instance reused by all four model sections below.
data_module = GenerativeMNISTDataModule()
GAN¶
Adapted from: https://lightning.ai/docs/pytorch/2.1.0/notebooks/lightning_examples/basic-gan.html
The model below implements a Generative Adversarial Network (GAN) for learning to generate MNIST digits. The model is not significantly prone to mode collapse: this is largely avoided here by using a relatively large latent space size and residually-connected bottleneck convolution networks for both the generator and discriminator components. (The original code linked-to above would rarely converge and typically exhibited extreme mode collapse due to using only standard, non-residual layers).
class ResidualConv2dLayer(pl.LightningModule):
    """Channel-preserving 3x3 convolution block with a residual skip.

    The explicit zero padding keeps the spatial size unchanged, so the
    transformed features can be added straight back onto the input.
    """
    def __init__(self,
                 size,
                 **kwargs):
        super().__init__(**kwargs)
        # Submodule names are kept stable for checkpoint compatibility.
        self.padding = torch.nn.ZeroPad2d((1, 1, 1, 1))
        self.conv_layer = torch.nn.Conv2d(size, size,
                                          kernel_size=(3, 3),
                                          stride=(1, 1))
        self.norm = torch.nn.BatchNorm2d(size)
        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return x plus the conv->norm->activation transform of x."""
        residual = self.activation(self.norm(self.conv_layer(self.padding(x))))
        return x + residual
class Generator(pl.LightningModule):
    """Decodes a latent vector into a 28x28 image in the Tanh range [-1, 1].

    The spatial path upsamples 2 -> 4 -> 8 -> 7 (via a 2x2 conv crop)
    -> 14 -> 28, with a pair of residual conv blocks at every resolution.
    """
    def __init__(self,
                 latent_dim,
                 n_channels = 1,
                 **kwargs):
        super().__init__(**kwargs)
        # Build the layer list step by step; module order (and therefore
        # state_dict indices) is identical to the original construction.
        layers = [
            torch.nn.Linear(latent_dim, latent_dim * 4),
            torch.nn.Unflatten(1, (latent_dim, 2, 2)),   # -> (latent_dim, 2, 2)
        ]
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 2 -> 4
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 4 -> 8
        layers.append(torch.nn.Conv2d(latent_dim, latent_dim, (2, 2)))   # 8 -> 7
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 7 -> 14
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 14 -> 28
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Conv2d(latent_dim, n_channels, 1))        # to image channels
        layers.append(torch.nn.Tanh())
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map latent vectors (batch, latent_dim) to images (batch, C, 28, 28)."""
        return self.model(x)
# Sanity-check the generator: layer-by-layer summary for 5 latent vectors.
temp_g = Generator(128)
summary(temp_g,
input_size=(5,128))
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== Generator [5, 1, 28, 28] -- ├─Sequential: 1-1 [5, 1, 28, 28] -- │ └─Linear: 2-1 [5, 512] 66,048 │ └─Unflatten: 2-2 [5, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-3 [5, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-1 [5, 128, 4, 4] -- │ │ └─Conv2d: 3-2 [5, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-3 [5, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-4 [5, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-4 [5, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-5 [5, 128, 4, 4] -- │ │ └─Conv2d: 3-6 [5, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-7 [5, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-8 [5, 128, 2, 2] -- │ └─Upsample: 2-5 [5, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-6 [5, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-9 [5, 128, 6, 6] -- │ │ └─Conv2d: 3-10 [5, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-11 [5, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-12 [5, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-7 [5, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-13 [5, 128, 6, 6] -- │ │ └─Conv2d: 3-14 [5, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-15 [5, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-16 [5, 128, 4, 4] -- │ └─Upsample: 2-8 [5, 128, 8, 8] -- │ └─Conv2d: 2-9 [5, 128, 7, 7] 65,664 │ └─ResidualConv2dLayer: 2-10 [5, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-17 [5, 128, 9, 9] -- │ │ └─Conv2d: 3-18 [5, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-19 [5, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-20 [5, 128, 7, 7] -- │ └─ResidualConv2dLayer: 2-11 [5, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-21 [5, 128, 9, 9] -- │ │ └─Conv2d: 3-22 [5, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-23 [5, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-24 [5, 128, 7, 7] -- │ └─Upsample: 2-12 [5, 128, 14, 14] -- │ └─ResidualConv2dLayer: 2-13 [5, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-25 [5, 128, 16, 16] -- │ │ └─Conv2d: 3-26 [5, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-27 [5, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-28 [5, 128, 14, 14] -- │ 
└─ResidualConv2dLayer: 2-14 [5, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-29 [5, 128, 16, 16] -- │ │ └─Conv2d: 3-30 [5, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-31 [5, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-32 [5, 128, 14, 14] -- │ └─Upsample: 2-15 [5, 128, 28, 28] -- │ └─ResidualConv2dLayer: 2-16 [5, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-33 [5, 128, 30, 30] -- │ │ └─Conv2d: 3-34 [5, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-35 [5, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-36 [5, 128, 28, 28] -- │ └─ResidualConv2dLayer: 2-17 [5, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-37 [5, 128, 30, 30] -- │ │ └─Conv2d: 3-38 [5, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-39 [5, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-40 [5, 128, 28, 28] -- │ └─Conv2d: 2-18 [5, 1, 28, 28] 129 │ └─Tanh: 2-19 [5, 1, 28, 28] -- ========================================================================================== Total params: 1,610,241 Trainable params: 1,610,241 Non-trainable params: 0 Total mult-adds (Units.GIGABYTES): 1.57 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 21.79 Params size (MB): 6.44 Estimated Total Size (MB): 28.23 ==========================================================================================
# model_graph = draw_graph(temp_g,
# input_size=(5,128),
# depth=3)
# model_graph.visual_graph
class Discriminator(pl.LightningModule):
    """Scores 28x28 images as real (towards 1) or fake (towards 0).

    Mirrors the generator: pairs of residual conv blocks with average
    pooling down 28 -> 14 -> 7 -> (pad to 9, pool to 4) -> 2 -> 1,
    followed by a linear layer and sigmoid.
    """
    def __init__(self,
                 latent_dim,
                 n_channels=1,
                 **kwargs):
        super().__init__(**kwargs)
        # Module order matches the original construction exactly, so
        # state_dict indices are unchanged.
        layers = [torch.nn.Conv2d(n_channels, latent_dim, 1)]
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]    # 28x28
        layers.append(torch.nn.AvgPool2d((2, 2)))                        # -> 14
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.AvgPool2d((2, 2)))                        # -> 7
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers += [torch.nn.ZeroPad2d((1, 1, 1, 1)),                     # 7 -> 9
                   torch.nn.AvgPool2d((2, 2))]                           # -> 4
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.AvgPool2d((2, 2)))                        # -> 2
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers += [torch.nn.AvgPool2d((2, 2)),                           # -> 1
                   torch.nn.Flatten(),
                   torch.nn.Linear(latent_dim, 1),
                   torch.nn.Sigmoid()]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return (batch, 1) real/fake probabilities."""
        return self.model(x)
# Summarize the discriminator at the training batch size.
temp_d = Discriminator(128)
summary(temp_d,input_size=(data_module.batch_size,)+data_module.input_shape)
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== Discriminator [128, 1] -- ├─Sequential: 1-1 [128, 1] -- │ └─Conv2d: 2-1 [128, 128, 28, 28] 256 │ └─ResidualConv2dLayer: 2-2 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-1 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-2 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-3 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-4 [128, 128, 28, 28] -- │ └─ResidualConv2dLayer: 2-3 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-5 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-6 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-7 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-8 [128, 128, 28, 28] -- │ └─AvgPool2d: 2-4 [128, 128, 14, 14] -- │ └─ResidualConv2dLayer: 2-5 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-9 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-10 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-11 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-12 [128, 128, 14, 14] -- │ └─ResidualConv2dLayer: 2-6 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-13 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-14 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-15 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-16 [128, 128, 14, 14] -- │ └─AvgPool2d: 2-7 [128, 128, 7, 7] -- │ └─ResidualConv2dLayer: 2-8 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-17 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-18 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-19 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-20 [128, 128, 7, 7] -- │ └─ResidualConv2dLayer: 2-9 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-21 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-22 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-23 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-24 [128, 128, 7, 7] -- │ └─ZeroPad2d: 2-10 [128, 128, 9, 9] -- │ └─AvgPool2d: 2-11 [128, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-12 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-25 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-26 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 
3-27 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-28 [128, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-13 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-29 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-30 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-31 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-32 [128, 128, 4, 4] -- │ └─AvgPool2d: 2-14 [128, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-15 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-33 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-34 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-35 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-36 [128, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-16 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-37 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-38 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-39 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-40 [128, 128, 2, 2] -- │ └─AvgPool2d: 2-17 [128, 128, 1, 1] -- │ └─Flatten: 2-18 [128, 128] -- │ └─Linear: 2-19 [128, 1] 129 │ └─Sigmoid: 2-20 [128, 1] -- ========================================================================================== Total params: 1,478,785 Trainable params: 1,478,785 Non-trainable params: 0 Total mult-adds (Units.GIGABYTES): 39.66 ========================================================================================== Input size (MB): 0.40 Forward/backward pass size (MB): 652.74 Params size (MB): 5.92 Estimated Total Size (MB): 659.06 ==========================================================================================
# model_graph = draw_graph(temp_d,
# input_size=(data_module.batch_size,)+data_module.input_shape,
# depth=3)
# model_graph.visual_graph
class GAN(pl.LightningModule):
    """Vanilla GAN for MNIST, trained with manual (dual-optimizer) stepping.

    Each training batch performs one generator update followed by one
    discriminator update, both driven by binary cross-entropy on the
    discriminator's sigmoid output.
    """
    def __init__(
        self,
        latent_dim: int = 128,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        # Two optimizers -> we drive them by hand in training_step.
        self.automatic_optimization = False
        # networks
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator(latent_dim=self.hparams.latent_dim)

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predicted and target real/fake labels."""
        return torch.nn.functional.binary_cross_entropy(y_hat, y)

    def predict_step(self, batch):
        # The predict split supplies zero images (only their shape/device
        # are used) and labels that this unconditional model ignores.
        zeros, y_true = batch
        # Bug fix: previously read the *global* ``model`` instead of ``self``,
        # which breaks whenever the instance is bound to another name.
        z = torch.randn(zeros.shape[0], self.hparams.latent_dim).type_as(zeros)
        # Rescale Tanh output [-1, 1] -> [0, 1], channels-last for plotting.
        imgs = (self(z).permute((0,2,3,1))+1.0)/2.0
        return imgs

    def training_step(self, batch):
        imgs, _ = batch
        optimizer_g, optimizer_d = self.optimizers()
        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # ---- generator step ----
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)
        # NOTE(review): generator target of 0.5 (not the conventional 1.0)
        # softens the objective — confirm this label smoothing is intended.
        valid = torch.ones(imgs.size(0), 1)*0.5
        valid = valid.type_as(imgs)
        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # ---- discriminator step ----
        # Measure discriminator's ability to classify real from generated samples.
        self.toggle_optimizer(optimizer_d)
        # How well can it label real images as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        # How well can it label generated images as fake?  ``detach`` stops
        # discriminator gradients from flowing back into the generator.
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)
        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)
        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        """One Adam optimizer per network; no LR schedulers."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []
# Fresh model; uncomment the next line to resume from a saved checkpoint instead.
model = GAN(128)
# model = GAN.load_from_checkpoint("logs/gan/mnist/checkpoints/epoch=9-step=9380.ckpt")
Before Training...¶
# CSV logger writes metrics to logs/gan/mnist/metrics.csv (plotted below).
logger = pl.loggers.CSVLogger("logs",
name="gan",
version="mnist")
# Single-device trainer; progress-bar refresh lowered to reduce output churn.
trainer = pl.Trainer(
accelerator="auto",
devices=1,
max_epochs=10,
logger=logger,
callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=50)]
)
GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs
# Sample from the untrained GAN (first predict batch = 80 images).
imgs = trainer.predict(model, data_module)[0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] SLURM auto-requeueing enabled. Setting signal handlers. /opt/conda/lib/python3.12/site-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn(
Predicting: | | 0/? [00:00<?, ?it/s]
imgs.shape
torch.Size([80, 28, 28, 1])
# 10x8 grid: one subplot per generated sample.
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
plt.subplot(10, 8, i+1)
plt.imshow(imgs[i], cmap="gray")
plt.axis("off")
plt.show()
# Train the GAN
trainer.fit(model, data_module)
results = pd.read_csv("logs/gan/mnist/metrics.csv",sep=",")
# g_loss and d_loss are logged on separate rows, so each column contains
# NaNs for the other's rows; filter them out before plotting.
plt.plot(results["epoch"][np.logical_not(np.isnan(results["g_loss"]))],
results["g_loss"][np.logical_not(np.isnan(results["g_loss"]))],
label="Loss")
plt.ylabel('Generator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
plt.plot(results["epoch"][np.logical_not(np.isnan(results["d_loss"]))],
results["d_loss"][np.logical_not(np.isnan(results["d_loss"]))],
label="Loss")
plt.ylabel('Discriminator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
After Training...¶
# Bug fix: this assigned to ``img`` while the plotting cells below read
# ``imgs``, so the "After Training" figure re-displayed the stale
# pre-training samples.
imgs = trainer.predict(model, data_module)[0]
Predicting: | | 0/? [00:00<?, ?it/s]
imgs.shape
torch.Size([80, 28, 28, 1])
# 10x8 grid of post-training samples.
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
plt.subplot(10, 8, i+1)
plt.imshow(imgs[i], cmap="gray")
plt.axis("off")
plt.show()
AC-GAN¶
Inspired from: https://lightning.ai/docs/pytorch/2.1.0/notebooks/lightning_examples/basic-gan.html
This version prevents mode collapse by using the class labels to create an Auxiliary Classifier Generative Adversarial Network (AC-GAN). The discriminator can predict 11 classes: real digits 0 to 9 and class 10 for fakes (of any kind). The generator is provided an additional integer input, 0 to 9, to indicate which of the digits it should be generating (in addition to the random noise input vector). Mode collapse is therefore largely avoided since all 10 digit types can be requested and must be able to fool the discriminator. Thus, human supervision can help to avoid mode collapse, but requires labeled data sets.
class ResidualConv2dLayer(pl.LightningModule):
    """Residual 3x3 conv block: output = input + conv/norm/activation(input).

    Zero padding preserves the spatial size so the skip addition is valid.
    """
    def __init__(self,
                 size,
                 **kwargs):
        super().__init__(**kwargs)
        # Attribute names kept identical for checkpoint compatibility.
        self.padding = torch.nn.ZeroPad2d((1, 1, 1, 1))
        self.conv_layer = torch.nn.Conv2d(size, size,
                                          kernel_size=(3, 3),
                                          stride=(1, 1))
        self.norm = torch.nn.BatchNorm2d(size)
        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Apply pad -> conv -> batchnorm -> LeakyReLU, then add the skip."""
        transformed = self.activation(self.norm(self.conv_layer(self.padding(x))))
        return x + transformed
class Generator(pl.LightningModule):
    """Class-conditional generator: latent vector + digit label -> 28x28 image.

    The label is embedded into latent space and multiplied element-wise
    with the noise vector before decoding; the decoder itself is the same
    2 -> 4 -> 8 -> 7 -> 14 -> 28 residual/upsampling stack as the GAN's.
    """
    def __init__(self,
                 latent_dim,
                 n_channels = 1,
                 n_classes = 10,
                 **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        # One learned latent-sized vector per digit class.
        self.embedding = torch.nn.Embedding(n_classes, latent_dim)
        # Module order (and state_dict indices) matches the original.
        layers = [
            torch.nn.Linear(latent_dim, latent_dim * 4),
            torch.nn.Unflatten(1, (latent_dim, 2, 2)),   # -> (latent_dim, 2, 2)
        ]
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 2 -> 4
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 4 -> 8
        layers.append(torch.nn.Conv2d(latent_dim, latent_dim, (2, 2)))   # 8 -> 7
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 7 -> 14
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Upsample(scale_factor=2))                 # 14 -> 28
        layers += [ResidualConv2dLayer(latent_dim) for _ in range(2)]
        layers.append(torch.nn.Conv2d(latent_dim, n_channels, 1))
        layers.append(torch.nn.Tanh())
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x, c):
        """Decode noise ``x`` modulated by the embedding of class labels ``c``."""
        conditioned = x * self.embedding(c)
        return self.model(conditioned)
# Summary needs both inputs: 5 latent vectors plus 5 integer class labels.
temp_g = Generator(128)
summary(temp_g,
input_size=[(5,128),(5,)],
dtypes=[torch.float32,torch.long])
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== Generator [5, 1, 28, 28] -- ├─Embedding: 1-1 [5, 128] 1,280 ├─Sequential: 1-2 [5, 1, 28, 28] -- │ └─Linear: 2-1 [5, 512] 66,048 │ └─Unflatten: 2-2 [5, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-3 [5, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-1 [5, 128, 4, 4] -- │ │ └─Conv2d: 3-2 [5, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-3 [5, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-4 [5, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-4 [5, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-5 [5, 128, 4, 4] -- │ │ └─Conv2d: 3-6 [5, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-7 [5, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-8 [5, 128, 2, 2] -- │ └─Upsample: 2-5 [5, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-6 [5, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-9 [5, 128, 6, 6] -- │ │ └─Conv2d: 3-10 [5, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-11 [5, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-12 [5, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-7 [5, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-13 [5, 128, 6, 6] -- │ │ └─Conv2d: 3-14 [5, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-15 [5, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-16 [5, 128, 4, 4] -- │ └─Upsample: 2-8 [5, 128, 8, 8] -- │ └─Conv2d: 2-9 [5, 128, 7, 7] 65,664 │ └─ResidualConv2dLayer: 2-10 [5, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-17 [5, 128, 9, 9] -- │ │ └─Conv2d: 3-18 [5, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-19 [5, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-20 [5, 128, 7, 7] -- │ └─ResidualConv2dLayer: 2-11 [5, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-21 [5, 128, 9, 9] -- │ │ └─Conv2d: 3-22 [5, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-23 [5, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-24 [5, 128, 7, 7] -- │ └─Upsample: 2-12 [5, 128, 14, 14] -- │ └─ResidualConv2dLayer: 2-13 [5, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-25 [5, 128, 16, 16] -- │ │ └─Conv2d: 3-26 [5, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-27 [5, 128, 14, 14] 256 │ │ 
└─LeakyReLU: 3-28 [5, 128, 14, 14] -- │ └─ResidualConv2dLayer: 2-14 [5, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-29 [5, 128, 16, 16] -- │ │ └─Conv2d: 3-30 [5, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-31 [5, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-32 [5, 128, 14, 14] -- │ └─Upsample: 2-15 [5, 128, 28, 28] -- │ └─ResidualConv2dLayer: 2-16 [5, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-33 [5, 128, 30, 30] -- │ │ └─Conv2d: 3-34 [5, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-35 [5, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-36 [5, 128, 28, 28] -- │ └─ResidualConv2dLayer: 2-17 [5, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-37 [5, 128, 30, 30] -- │ │ └─Conv2d: 3-38 [5, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-39 [5, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-40 [5, 128, 28, 28] -- │ └─Conv2d: 2-18 [5, 1, 28, 28] 129 │ └─Tanh: 2-19 [5, 1, 28, 28] -- ========================================================================================== Total params: 1,611,521 Trainable params: 1,611,521 Non-trainable params: 0 Total mult-adds (Units.GIGABYTES): 1.57 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 21.79 Params size (MB): 6.45 Estimated Total Size (MB): 28.24 ==========================================================================================
# model_graph = draw_graph(temp_g,
# input_size=[(5,128),(5,)],
# dtypes=[torch.float32,torch.long],
# depth=3)
# model_graph.visual_graph
class Discriminator(pl.LightningModule):
    """Classifies images into ``n_classes + 1`` categories: digits 0-9 (real)
    plus class 10 for fakes.

    Same residual downsampling trunk as the GAN discriminator, but the head
    emits raw logits (no sigmoid/softmax) as expected by cross-entropy loss.
    """
    def __init__(self,
                 latent_dim,
                 n_channels=1,
                 n_classes=10,
                 **kwargs):
        # Consistency fix: accept and forward **kwargs like every sibling
        # module in this notebook (backward-compatible addition).
        super().__init__(**kwargs)
        self.model = torch.nn.Sequential(*(
            [torch.nn.Conv2d(n_channels,latent_dim,1)]+
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 28
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 14
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 7
            [torch.nn.ZeroPad2d((1,1,1,1)), torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 4
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 2
            [torch.nn.AvgPool2d((2,2)),
             torch.nn.Flatten(),
             torch.nn.Linear(latent_dim,n_classes+1)]
        ))

    def forward(self, x):
        """Return (batch, n_classes + 1) classification logits."""
        return self.model(x)
# Summarize the 11-way discriminator at the training batch size.
temp_d = Discriminator(128)
summary(temp_d,input_size=(data_module.batch_size,)+data_module.input_shape)
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== Discriminator [128, 11] -- ├─Sequential: 1-1 [128, 11] -- │ └─Conv2d: 2-1 [128, 128, 28, 28] 256 │ └─ResidualConv2dLayer: 2-2 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-1 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-2 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-3 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-4 [128, 128, 28, 28] -- │ └─ResidualConv2dLayer: 2-3 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-5 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-6 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-7 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-8 [128, 128, 28, 28] -- │ └─AvgPool2d: 2-4 [128, 128, 14, 14] -- │ └─ResidualConv2dLayer: 2-5 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-9 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-10 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-11 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-12 [128, 128, 14, 14] -- │ └─ResidualConv2dLayer: 2-6 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-13 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-14 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-15 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-16 [128, 128, 14, 14] -- │ └─AvgPool2d: 2-7 [128, 128, 7, 7] -- │ └─ResidualConv2dLayer: 2-8 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-17 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-18 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-19 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-20 [128, 128, 7, 7] -- │ └─ResidualConv2dLayer: 2-9 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-21 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-22 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-23 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-24 [128, 128, 7, 7] -- │ └─ZeroPad2d: 2-10 [128, 128, 9, 9] -- │ └─AvgPool2d: 2-11 [128, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-12 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-25 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-26 [128, 128, 4, 4] 147,584 │ │ 
└─BatchNorm2d: 3-27 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-28 [128, 128, 4, 4] -- │ └─ResidualConv2dLayer: 2-13 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-29 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-30 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-31 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-32 [128, 128, 4, 4] -- │ └─AvgPool2d: 2-14 [128, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-15 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-33 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-34 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-35 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-36 [128, 128, 2, 2] -- │ └─ResidualConv2dLayer: 2-16 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-37 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-38 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-39 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-40 [128, 128, 2, 2] -- │ └─AvgPool2d: 2-17 [128, 128, 1, 1] -- │ └─Flatten: 2-18 [128, 128] -- │ └─Linear: 2-19 [128, 11] 1,419 ========================================================================================== Total params: 1,480,075 Trainable params: 1,480,075 Non-trainable params: 0 Total mult-adds (Units.GIGABYTES): 39.66 ========================================================================================== Input size (MB): 0.40 Forward/backward pass size (MB): 652.75 Params size (MB): 5.92 Estimated Total Size (MB): 659.07 ==========================================================================================
# model_graph = draw_graph(temp_d,
# input_size=(data_module.batch_size,)+data_module.input_shape,
# depth=3)
# model_graph.visual_graph
class ACGAN(pl.LightningModule):
    """Auxiliary-Classifier GAN for MNIST with manual dual-optimizer training.

    The discriminator outputs logits over 11 classes (digits 0-9 real,
    class 10 fake); both networks are trained with cross-entropy.  The
    generator is rewarded when its class-conditioned fakes are classified
    as the requested real digit.
    """
    def __init__(
        self,
        latent_dim: int = 128,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        # Two optimizers -> manual stepping in training_step.
        self.automatic_optimization = False
        # networks
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator(latent_dim=self.hparams.latent_dim)

    def forward(self, z, c):
        """Generate images from latent vectors ``z`` conditioned on labels ``c``."""
        return self.generator(z, c)

    def adversarial_loss(self, y_hat, y):
        """Cross-entropy over the 11 class logits."""
        return torch.nn.functional.cross_entropy(y_hat, y)

    def predict_step(self, batch):
        # Predict split supplies placeholder zeros (shape/device) and the
        # digit labels to condition on (0..9, eight of each).
        zeros, y_true = batch
        # Bug fix: previously read the *global* ``model`` instead of ``self``,
        # which breaks whenever the instance is bound to another name.
        z = torch.randn(zeros.shape[0], self.hparams.latent_dim).type_as(zeros)
        # Rescale Tanh output [-1, 1] -> [0, 1], channels-last for plotting.
        imgs = (self(z,y_true).permute((0,2,3,1))+1.0)/2.0
        return imgs

    def training_step(self, batch):
        imgs, y_true = batch
        optimizer_g, optimizer_d = self.optimizers()
        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)
        # Random digit labels for the fakes (uniform over 0..9).
        c = torch.randint(10,y_true.shape,dtype=torch.long)
        c = c.type_as(y_true)

        # ---- generator step ----
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z,c)
        # Generator wins when its fakes are classified as the requested digit.
        g_loss = self.adversarial_loss(self.discriminator(self(z,c)), c)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # ---- discriminator step ----
        # Measure discriminator's ability to classify real from generated samples.
        self.toggle_optimizer(optimizer_d)
        # Real images should be classified as their true digit (0-9).
        real_loss = self.adversarial_loss(self.discriminator(imgs), y_true)
        # Generated images should be classified as class 10 ("fake");
        # ``detach`` keeps discriminator gradients out of the generator.
        fake = torch.ones(imgs.size(0), dtype=torch.long)*10
        fake = fake.type_as(y_true)
        fake_loss = self.adversarial_loss(self.discriminator(self(z,c).detach()), fake)
        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        """One Adam optimizer per network; no LR schedulers."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []
# Fresh model; uncomment the next line to resume from a saved checkpoint instead.
model = ACGAN(128)
# model = ACGAN.load_from_checkpoint("logs/ac-gan/mnist/checkpoints/epoch=9-step=9380.ckpt")
Before Training...¶
# CSV logger writes metrics to logs/ac-gan/mnist/metrics.csv (plotted below).
logger = pl.loggers.CSVLogger("logs",
name="ac-gan",
version="mnist")
# Single-device trainer; progress-bar refresh lowered to reduce output churn.
trainer = pl.Trainer(
accelerator="auto",
devices=1,
max_epochs=10,
logger=logger,
callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=50)]
)
GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs
# Sample from the untrained AC-GAN (rows of 8 per requested digit 0-9).
imgs = trainer.predict(model, data_module)[0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] SLURM auto-requeueing enabled. Setting signal handlers. /opt/conda/lib/python3.12/site-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn(
Predicting: | | 0/? [00:00<?, ?it/s]
imgs.shape
torch.Size([80, 28, 28, 1])
# 10x8 grid: each row corresponds to one requested digit class.
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
plt.subplot(10, 8, i+1)
plt.imshow(imgs[i], cmap="gray")
plt.axis("off")
plt.show()
trainer.fit(model, data_module)
results = pd.read_csv("logs/ac-gan/mnist/metrics.csv",sep=",")
# g_loss and d_loss are logged on separate rows, so each column contains
# NaNs for the other's rows; filter them out before plotting.
plt.plot(results["epoch"][np.logical_not(np.isnan(results["g_loss"]))],
results["g_loss"][np.logical_not(np.isnan(results["g_loss"]))],
label="Loss")
plt.ylabel('Generator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
plt.plot(results["epoch"][np.logical_not(np.isnan(results["d_loss"]))],
results["d_loss"][np.logical_not(np.isnan(results["d_loss"]))],
label="Loss")
plt.ylabel('Discriminator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
After Training...¶
# Sample again after training, same conditioning labels.
imgs = trainer.predict(model, data_module)[0]
Predicting: | | 0/? [00:00<?, ?it/s]
imgs.shape
torch.Size([80, 28, 28, 1])
# 10x8 grid of post-training samples, one digit class per row.
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
plt.subplot(10, 8, i+1)
plt.imshow(imgs[i], cmap="gray")
plt.axis("off")
plt.show()
DDPM¶
Denoising Diffusion Probabilistic Model¶
Adapted from: https://github.com/tcapelle/Diffusion-Models-pytorch/
The model below implements a Denoising Diffusion Probabilistic Model (DDPM) for learning to generate MNIST digits. Given the nature of stable diffusion training, the model is not significantly prone to mode collapse, partially due to the relatively large latent space size and residually-connected U net convolution network used, but mainly due to the denoising training regime of DDPMs. Much of the training and prediction steps were inspired by the code linked-to above, but everything was ported over to PyTorch Lightning, the noise-prediction architecture is based more on the GAN code above (using the principles of U net construction), a more standardized version of position encoding to provide the time-conditioning input was used, and MeanLogCoshError (MLCE) loss is used instead of MeanSquaredError (MSE) loss to make for slightly more stable training in general.
class SinePositionEncoding(pl.LightningModule):
    """Sinusoidal position encoding used to embed diffusion timesteps.

    Maps a 1-D tensor of timestep values to a ``(len(x), dim)`` tensor in
    which even channels carry sine terms and odd channels carry cosine
    terms, with wavelengths geometrically spaced up to ``max_wavelength``.
    """

    def __init__(self,
                 dim,
                 max_wavelength=10000.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_wavelength = torch.Tensor([max_wavelength])

    def forward(self, x):
        # x holds the (batch,) timestep values to encode.
        base_freq = (1 / self.max_wavelength).type_as(x)
        # One shared frequency per pair of adjacent output channels.
        exponents = (2 * (torch.arange(self.dim) // 2)).type_as(x)
        timescales = torch.pow(
            base_freq,
            exponents / torch.Tensor([self.dim]).type_as(x)
        )
        angles = torch.unsqueeze(x, 1) * torch.unsqueeze(timescales, 0)
        # Interleave: sine on even channels, cosine on odd channels.
        odd_mask = (torch.arange(self.dim) % 2).type_as(x)
        even_mask = 1 - odd_mask
        encodings = torch.sin(angles) * even_mask + torch.cos(angles) * odd_mask
        return encodings.type_as(x)
class TimeConditionedResidualConv2dLayer(pl.LightningModule):
    """Residual 3x3 convolution block with an additive time-embedding bias.

    ``forward`` accepts and returns an ``(activations, time_embedding)``
    pair so that several of these layers can be chained inside a
    ``torch.nn.Sequential``.
    """

    def __init__(self,
                 size,
                 t_dim=256,
                 **kwargs):
        super().__init__(**kwargs)
        self.padding = torch.nn.ZeroPad2d((1, 1, 1, 1))
        self.conv_layer = torch.nn.Conv2d(size,
                                          size,
                                          kernel_size=(3, 3),
                                          stride=(1, 1))
        self.norm = torch.nn.BatchNorm2d(size)
        # Projects the t_dim-wide time embedding to one bias per channel.
        self.embed_t = torch.nn.Linear(t_dim, size)
        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        features, t = x
        out = self.activation(self.norm(self.conv_layer(self.padding(features))))
        # Broadcast the per-channel time bias across both spatial dimensions,
        # then add the residual input back in.
        t_bias = self.embed_t(t)[:, :, None, None].repeat(
            1, 1, features.shape[-2], features.shape[-1])
        return out + features + t_bias, t
class TimeConditionedUNet(pl.LightningModule):
    """Residual U-Net that predicts the noise added to an image at a given
    diffusion timestep.

    The encoder halves the spatial resolution five times
    (28 -> 14 -> 7 -> (pad to 8) -> 4 -> 2 -> 1) and the decoder mirrors it
    with upsampling; the matching encoder activation is added back at each
    decoder stage (skip connections). Every residual conv layer is
    conditioned on a sinusoidal embedding of the timestep.

    Parameters:
        latent_dim: number of channels used throughout the U-Net.
        t_dim: size of the timestep-embedding vector.
        n_channels: channels of the input/output image (1 for MNIST).
        depth: residual layers per encoder stage.
    """
    def __init__(self,
                 latent_dim,
                 t_dim = 256,
                 n_channels = 1,
                 depth = 5,
                 **kwargs):
        super().__init__(**kwargs)
        # Bug fix: time_embedding() reads self.t_dim, but this assignment was
        # commented out, so calling that method raised AttributeError.
        self.t_dim = t_dim
        self.time_encoding = SinePositionEncoding(t_dim)
        self.encoder0 = torch.nn.Conv2d(n_channels,latent_dim,1)
        self.encoder1 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 28
        self.encoder2 = torch.nn.AvgPool2d((2,2))
        self.encoder3 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 14
        self.encoder4 = torch.nn.AvgPool2d((2,2))
        self.encoder5 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 7
        # Pad 7x7 -> 9x9 so the next pooling yields an even 4x4 map.
        self.encoder6 = torch.nn.ZeroPad2d((1,1,1,1))
        self.encoder7 = torch.nn.AvgPool2d((2,2))
        self.encoder8 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 4
        self.encoder9 = torch.nn.AvgPool2d((2,2))
        self.encoder10 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 2
        self.encoder11 = torch.nn.AvgPool2d((2,2))
        self.encoder12 = torch.nn.Flatten()
        self.decoder0 = torch.nn.Linear(latent_dim,latent_dim*4)
        self.decoder1 = torch.nn.Unflatten(1,(latent_dim,2,2))
        self.decoder2 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder3 = torch.nn.Upsample(scale_factor=2)
        self.decoder4 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder5 = torch.nn.Upsample(scale_factor=2)
        # 2x2 valid convolution shrinks the upsampled 8x8 map back to 7x7.
        self.decoder6 = torch.nn.Conv2d(latent_dim,latent_dim,(2,2))
        self.decoder7 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder8 = torch.nn.Upsample(scale_factor=2)
        # Consistency: pass t_dim by keyword here as in every other stage.
        self.decoder9 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder10 = torch.nn.Upsample(scale_factor=2)
        self.decoder11 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder12 = torch.nn.Conv2d(latent_dim,n_channels,1)
    def time_embedding(self, t):
        """Alternate sin/cos timestep embedding (currently unused; forward()
        uses self.time_encoding instead — see the commented line there)."""
        channels = self.t_dim
        inv_freq = (1.0 / (
            10000
            ** (torch.arange(0, channels, 2).float() / channels)
        )).type_as(t)
        time_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        time_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        time_enc = torch.cat([time_enc_a, time_enc_b], dim=-1)
        return time_enc
    def forward(self, x, t):
        """Predict the noise in x at timestep t; output has x's shape."""
        t = self.time_encoding(t.type_as(x))
        # t = self.time_embedding(t.unsqueeze(-1).type_as(x))
        # y0/y1/y3/y5/y8/y10 are saved for the decoder's skip connections.
        y = y0 = x
        y = self.encoder0(y)
        y = y1 = self.encoder1((y,t))[0]
        y = self.encoder2(y)
        y = y3 = self.encoder3((y,t))[0]
        y = self.encoder4(y)
        y = y5 = self.encoder5((y,t))[0]
        y = self.encoder6(y)
        y = self.encoder7(y)
        y = y8 = self.encoder8((y,t))[0]
        y = self.encoder9(y)
        y = y10 = self.encoder10((y,t))[0]
        y = self.encoder11(y)
        y = self.encoder12(y)
        y = self.decoder0(y)
        y = self.decoder1(y)
        y = self.decoder2((y,t))[0] + y10
        y = self.decoder3(y)
        y = self.decoder4((y,t))[0] + y8
        y = self.decoder5(y)
        y = self.decoder6(y)
        y = self.decoder7((y,t))[0] + y5
        y = self.decoder8(y)
        y = self.decoder9((y,t))[0] + y3
        y = self.decoder10(y)
        y = self.decoder11((y,t))[0] + y1
        # Final residual connection to the raw input image.
        y = self.decoder12(y) + y0
        return y
# Instantiate a 128-channel U-Net and print a layer-by-layer summary
# for one batch of images plus one batch of timestep scalars.
temp_u = TimeConditionedUNet(128)
image_batch_shape = (data_module.batch_size,) + data_module.input_shape
time_batch_shape = (data_module.batch_size,)
summary(temp_u, input_size=[image_batch_shape, time_batch_shape])
========================================================================================================= Layer (type:depth-idx) Output Shape Param # ========================================================================================================= TimeConditionedUNet [128, 1, 28, 28] -- ├─SinePositionEncoding: 1-1 [128, 256] -- ├─Conv2d: 1-2 [128, 128, 28, 28] 256 ├─Sequential: 1-3 [128, 128, 28, 28] -- │ └─TimeConditionedResidualConv2dLayer: 2-1 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-1 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-2 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-3 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-4 [128, 128, 28, 28] -- │ │ └─Linear: 3-5 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-2 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-6 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-7 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-8 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-9 [128, 128, 28, 28] -- │ │ └─Linear: 3-10 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-3 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-11 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-12 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-13 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-14 [128, 128, 28, 28] -- │ │ └─Linear: 3-15 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-4 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-16 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-17 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-18 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-19 [128, 128, 28, 28] -- │ │ └─Linear: 3-20 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-5 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-21 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-22 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-23 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-24 [128, 128, 28, 28] -- │ │ └─Linear: 3-25 [128, 128] 32,896 ├─AvgPool2d: 1-4 [128, 128, 14, 14] -- ├─Sequential: 1-5 [128, 128, 14, 14] -- │ └─TimeConditionedResidualConv2dLayer: 2-6 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-26 [128, 
128, 16, 16] -- │ │ └─Conv2d: 3-27 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-28 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-29 [128, 128, 14, 14] -- │ │ └─Linear: 3-30 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-7 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-31 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-32 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-33 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-34 [128, 128, 14, 14] -- │ │ └─Linear: 3-35 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-8 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-36 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-37 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-38 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-39 [128, 128, 14, 14] -- │ │ └─Linear: 3-40 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-9 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-41 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-42 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-43 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-44 [128, 128, 14, 14] -- │ │ └─Linear: 3-45 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-10 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-46 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-47 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-48 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-49 [128, 128, 14, 14] -- │ │ └─Linear: 3-50 [128, 128] 32,896 ├─AvgPool2d: 1-6 [128, 128, 7, 7] -- ├─Sequential: 1-7 [128, 128, 7, 7] -- │ └─TimeConditionedResidualConv2dLayer: 2-11 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-51 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-52 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-53 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-54 [128, 128, 7, 7] -- │ │ └─Linear: 3-55 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-12 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-56 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-57 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-58 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-59 [128, 128, 7, 7] -- │ │ └─Linear: 3-60 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-13 
[128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-61 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-62 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-63 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-64 [128, 128, 7, 7] -- │ │ └─Linear: 3-65 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-14 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-66 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-67 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-68 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-69 [128, 128, 7, 7] -- │ │ └─Linear: 3-70 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-15 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-71 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-72 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-73 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-74 [128, 128, 7, 7] -- │ │ └─Linear: 3-75 [128, 128] 32,896 ├─ZeroPad2d: 1-8 [128, 128, 9, 9] -- ├─AvgPool2d: 1-9 [128, 128, 4, 4] -- ├─Sequential: 1-10 [128, 128, 4, 4] -- │ └─TimeConditionedResidualConv2dLayer: 2-16 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-76 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-77 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-78 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-79 [128, 128, 4, 4] -- │ │ └─Linear: 3-80 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-17 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-81 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-82 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-83 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-84 [128, 128, 4, 4] -- │ │ └─Linear: 3-85 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-18 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-86 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-87 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-88 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-89 [128, 128, 4, 4] -- │ │ └─Linear: 3-90 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-19 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-91 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-92 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-93 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-94 [128, 128, 4, 4] -- │ │ └─Linear: 3-95 [128, 128] 32,896 │ 
└─TimeConditionedResidualConv2dLayer: 2-20 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-96 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-97 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-98 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-99 [128, 128, 4, 4] -- │ │ └─Linear: 3-100 [128, 128] 32,896 ├─AvgPool2d: 1-11 [128, 128, 2, 2] -- ├─Sequential: 1-12 [128, 128, 2, 2] -- │ └─TimeConditionedResidualConv2dLayer: 2-21 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-101 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-102 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-103 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-104 [128, 128, 2, 2] -- │ │ └─Linear: 3-105 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-22 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-106 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-107 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-108 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-109 [128, 128, 2, 2] -- │ │ └─Linear: 3-110 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-23 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-111 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-112 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-113 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-114 [128, 128, 2, 2] -- │ │ └─Linear: 3-115 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-24 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-116 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-117 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-118 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-119 [128, 128, 2, 2] -- │ │ └─Linear: 3-120 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-25 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-121 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-122 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-123 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-124 [128, 128, 2, 2] -- │ │ └─Linear: 3-125 [128, 128] 32,896 ├─AvgPool2d: 1-13 [128, 128, 1, 1] -- ├─Flatten: 1-14 [128, 128] -- ├─Linear: 1-15 [128, 512] 66,048 ├─Unflatten: 1-16 [128, 128, 2, 2] -- ├─Sequential: 1-17 [128, 128, 2, 2] -- │ └─TimeConditionedResidualConv2dLayer: 2-26 [128, 128, 2, 2] -- │ │ 
└─ZeroPad2d: 3-126 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-127 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-128 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-129 [128, 128, 2, 2] -- │ │ └─Linear: 3-130 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-27 [128, 128, 2, 2] -- │ │ └─ZeroPad2d: 3-131 [128, 128, 4, 4] -- │ │ └─Conv2d: 3-132 [128, 128, 2, 2] 147,584 │ │ └─BatchNorm2d: 3-133 [128, 128, 2, 2] 256 │ │ └─LeakyReLU: 3-134 [128, 128, 2, 2] -- │ │ └─Linear: 3-135 [128, 128] 32,896 ├─Upsample: 1-18 [128, 128, 4, 4] -- ├─Sequential: 1-19 [128, 128, 4, 4] -- │ └─TimeConditionedResidualConv2dLayer: 2-28 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-136 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-137 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-138 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-139 [128, 128, 4, 4] -- │ │ └─Linear: 3-140 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-29 [128, 128, 4, 4] -- │ │ └─ZeroPad2d: 3-141 [128, 128, 6, 6] -- │ │ └─Conv2d: 3-142 [128, 128, 4, 4] 147,584 │ │ └─BatchNorm2d: 3-143 [128, 128, 4, 4] 256 │ │ └─LeakyReLU: 3-144 [128, 128, 4, 4] -- │ │ └─Linear: 3-145 [128, 128] 32,896 ├─Upsample: 1-20 [128, 128, 8, 8] -- ├─Conv2d: 1-21 [128, 128, 7, 7] 65,664 ├─Sequential: 1-22 [128, 128, 7, 7] -- │ └─TimeConditionedResidualConv2dLayer: 2-30 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-146 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-147 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-148 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-149 [128, 128, 7, 7] -- │ │ └─Linear: 3-150 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-31 [128, 128, 7, 7] -- │ │ └─ZeroPad2d: 3-151 [128, 128, 9, 9] -- │ │ └─Conv2d: 3-152 [128, 128, 7, 7] 147,584 │ │ └─BatchNorm2d: 3-153 [128, 128, 7, 7] 256 │ │ └─LeakyReLU: 3-154 [128, 128, 7, 7] -- │ │ └─Linear: 3-155 [128, 128] 32,896 ├─Upsample: 1-23 [128, 128, 14, 14] -- ├─Sequential: 1-24 [128, 128, 14, 14] -- │ └─TimeConditionedResidualConv2dLayer: 2-32 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-156 [128, 128, 16, 16] -- │ │ 
└─Conv2d: 3-157 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-158 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-159 [128, 128, 14, 14] -- │ │ └─Linear: 3-160 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-33 [128, 128, 14, 14] -- │ │ └─ZeroPad2d: 3-161 [128, 128, 16, 16] -- │ │ └─Conv2d: 3-162 [128, 128, 14, 14] 147,584 │ │ └─BatchNorm2d: 3-163 [128, 128, 14, 14] 256 │ │ └─LeakyReLU: 3-164 [128, 128, 14, 14] -- │ │ └─Linear: 3-165 [128, 128] 32,896 ├─Upsample: 1-25 [128, 128, 28, 28] -- ├─Sequential: 1-26 [128, 128, 28, 28] -- │ └─TimeConditionedResidualConv2dLayer: 2-34 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-166 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-167 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-168 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-169 [128, 128, 28, 28] -- │ │ └─Linear: 3-170 [128, 128] 32,896 │ └─TimeConditionedResidualConv2dLayer: 2-35 [128, 128, 28, 28] -- │ │ └─ZeroPad2d: 3-171 [128, 128, 30, 30] -- │ │ └─Conv2d: 3-172 [128, 128, 28, 28] 147,584 │ │ └─BatchNorm2d: 3-173 [128, 128, 28, 28] 256 │ │ └─LeakyReLU: 3-174 [128, 128, 28, 28] -- │ │ └─Linear: 3-175 [128, 128] 32,896 ├─Conv2d: 1-27 [128, 1, 28, 28] 129 ========================================================================================================= Total params: 6,457,857 Trainable params: 6,457,857 Non-trainable params: 0 Total mult-adds (Units.GIGABYTES): 139.32 ========================================================================================================= Input size (MB): 0.40 Forward/backward pass size (MB): 2040.02 Params size (MB): 25.83 Estimated Total Size (MB): 2066.25 =========================================================================================================
# This kind of U-Net is large and difficult to visualize.
# The rendered PNG is too fuzzy to read (it was already
# borderline for the GAN components above). Running the
# code below interactively gives a better visualization.
# model_graph = draw_graph(temp_u,
# input_size=[(data_module.batch_size,)+data_module.input_shape,
# (data_module.batch_size,)],
# depth=3)
# model_graph.visual_graph